"""
measure_wilson.py
~~~~~~~~~~~~~~~~~~

Compute Wilson loops on a two‑dimensional lattice for one or more gauge
groups.  For each loop size ``L`` the average value of the loop operator is
estimated by placing an ``L×L`` square contour at randomly sampled lattice
sites and multiplying the link variables along its edges (links traversed
against their orientation enter as their Hermitian conjugate).  For
non‑Abelian gauge groups the trace of the resulting product matrix is taken;
for U(1) the matrix is one‑dimensional and the trace reduces to the complex
number itself.

This module exposes two key helpers:

* ``build_link_index_map`` constructs a mapping from a lattice of links
  produced by a square lattice builder to a compact integer index.  The
  expected lattice representation is a list of tuples ``((x, y), μ)``
  returned by an external ``build_lattice`` function.
* ``compute_wilson_loop_average`` computes the average value of a Wilson
  loop of a given size on a periodic square lattice for a specified gauge
  group.  It samples random starting positions to approximate the full
  average.

This file intentionally omits the CLI entry point present in the original
repository, as the FPH integration uses bespoke orchestration.
"""

from __future__ import annotations

from typing import Dict, List, Optional, Tuple

import numpy as np


def build_link_index_map(lattice: np.ndarray) -> Dict[Tuple[int, int, int], int]:
    """Map every link ``(x, y, μ)`` to its position in ``lattice``.

    Parameters
    ----------
    lattice : np.ndarray
        Sequence of link descriptors as produced by ``build_lattice``; every
        entry must unpack as ``((x, y), μ)``.  The entries are only
        enumerated, so any iterable of such pairs is accepted.

    Returns
    -------
    dict
        ``{(x, y, μ): index}`` where all three key components are converted
        to Python ``int``.  If the same link occurs more than once, the last
        occurrence wins.
    """
    return {
        (int(site_x), int(site_y), int(direction)): position
        for position, ((site_x, site_y), direction) in enumerate(lattice)
    }


def _wilson_loop_path(
    x: int, y: int, L: int, N: int
) -> List[Tuple[int, int, int, bool]]:
    """Return the links of an ``L×L`` loop anchored at site ``(x, y)``.

    The loop is traversed counter-clockwise starting at the lower-left corner
    ``(x, y)``: +x along the bottom edge, +y along the right edge, -x along
    the top edge and -y along the left edge.  Each entry is
    ``(x_link, y_link, μ, dagger)``, where ``(x_link, y_link, μ)`` is the key
    of the link that *starts* at ``(x_link, y_link)`` in direction μ
    (0 = +x, 1 = +y) and ``dagger`` is True when the path runs against that
    link's orientation.  Coordinates wrap periodically with period ``N``.
    """
    path: List[Tuple[int, int, int, bool]] = []
    # Bottom edge: (x, y) -> (x + L, y).
    for s in range(L):
        path.append(((x + s) % N, y % N, 0, False))
    # Right edge: (x + L, y) -> (x + L, y + L).
    for s in range(L):
        path.append(((x + L) % N, (y + s) % N, 1, False))
    # Top edge: (x + L, y + L) -> (x, y + L).  Stepping from column c to
    # c - 1 traverses the +x link that *starts* at c - 1, hence the "- 1".
    for s in range(L):
        path.append(((x + L - 1 - s) % N, (y + L) % N, 0, True))
    # Left edge: (x, y + L) -> (x, y); same reasoning for the +y links.
    for s in range(L):
        path.append((x % N, (y + L - 1 - s) % N, 1, True))
    return path


def compute_wilson_loop_average(
    U: np.ndarray,
    mapping: Dict[Tuple[int, int, int], int],
    lattice_size: int,
    loop_size: int,
    gauge_group: str,
    n_samples: int = 5000,
    rng: Optional[np.random.Generator] = None,
) -> complex:
    """Compute the average Wilson loop of a given size for a gauge group.

    For each sampled lower-left corner the path-ordered product of the link
    variables around a closed ``loop_size × loop_size`` square is formed.
    Links traversed against their orientation enter as their Hermitian
    conjugate.  The product is traced (non-Abelian) or used directly (U(1)),
    and the results are averaged.  The trace is *not* normalized by the
    matrix dimension.

    Parameters
    ----------
    U : np.ndarray
        Link variables for the gauge group.  For U(1) this is a one‑dimensional
        complex array of length ``num_links``.  For SU(2) and SU(3) it is an
        array of shape ``(num_links, N, N)`` with ``N`` the dimension of the
        representation (2 or 3).  Matrices need not be diagonal; the product
        is taken in path order.
    mapping : dict
        Mapping from ``(x, y, μ)`` to link index, with μ = 0 for the +x link
        and μ = 1 for the +y link starting at site ``(x, y)``.
    lattice_size : int
        Number of sites along each dimension of the periodic square lattice.
    loop_size : int
        Size of the square loop (number of links along each side).
    gauge_group : str
        ``'U1'`` (case-insensitive) selects scalar complex links; any other
        value (e.g. ``'SU2'``, ``'SU3'``) selects matrix links.
    n_samples : int, optional
        Number of random starting positions to sample for averaging. Default is 5000.
    rng : numpy.random.Generator, optional
        Generator used to draw the starting positions.  When omitted, NumPy's
        global random state (``np.random.randint``) is used, as before.

    Returns
    -------
    complex
        The average value of the Wilson loop operator for the given gauge and
        loop size.

    Raises
    ------
    ValueError
        If ``n_samples`` is not positive.
    """
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")

    is_u1 = gauge_group.upper() == 'U1'
    L = loop_size
    N = lattice_size

    # Random lower-left corners.  Without an explicit generator, keep using
    # the global NumPy stream so callers that seed ``np.random`` get the same
    # sample positions as before.
    if rng is None:
        x_starts = np.random.randint(0, N, n_samples)
        y_starts = np.random.randint(0, N, n_samples)
    else:
        x_starts = rng.integers(0, N, n_samples)
        y_starts = rng.integers(0, N, n_samples)

    total = 0.0 + 0.0j
    for x, y in zip(x_starts.tolist(), y_starts.tolist()):
        product = 1.0 + 0.0j if is_u1 else np.eye(U.shape[1], dtype=complex)
        for lx, ly, mu, dagger in _wilson_loop_path(x, y, L, N):
            link = U[mapping[(lx, ly, mu)]]
            if is_u1:
                product *= np.conjugate(link) if dagger else link
            else:
                # Right-multiplication keeps the product in path order.
                product = product @ (link.conj().T if dagger else link)
        total += product if is_u1 else np.trace(product)

    return total / n_samples


# Public API re-exported by ``from measure_wilson import *``.
__all__ = ["build_link_index_map", "compute_wilson_loop_average"]